{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The autoreload extension is already loaded. To reload it, use:\n",
      "  %reload_ext autoreload\n"
     ]
    }
   ],
   "source": [
    "# %reload_ext loads the extension when it is absent and reloads it otherwise,\n",
    "# so re-running this cell does not print the 'already loaded' notice.\n",
    "%reload_ext autoreload\n",
    "# Mode 2: reload all modules before executing each cell.\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Configure the runtime through the project helper, requesting device '0'.\n",
    "# NOTE(review): presumably this has to run before TensorFlow is imported in the\n",
    "# next cell (e.g. to restrict the visible GPUs) - confirm in path_explain.utils.\n",
    "from path_explain import utils\n",
    "utils.set_up_environment(visible_devices='0')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "import tensorflow_datasets\n",
    "import numpy as np\n",
    "import scipy\n",
    "from transformers import *\n",
    "from plot.text import text_plot"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Data and Model Loading"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {},
   "outputs": [],
   "source": [
    "# GLUE task key; used to look up the task processor here and passed to the\n",
    "# feature conversion later in the notebook.\n",
    "task = 'sst-2'\n",
    "\n",
    "# Number of classes = length of the label list reported by the task processor.\n",
    "label_list = glue_processors[task]().get_labels()\n",
    "num_labels = len(label_list)\n",
    "\n",
    "tokenizer = DistilBertTokenizer.from_pretrained(\n",
    "    'distilbert-base-uncased'\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reference model: the 'distilbert-base-uncased' checkpoint, configured with a\n",
    "# classification head of num_labels outputs.\n",
    "large_config = DistilBertConfig.from_pretrained(\n",
    "    'distilbert-base-uncased',\n",
    "    num_labels=num_labels,\n",
    ")\n",
    "large_model = TFDistilBertForSequenceClassification.from_pretrained(\n",
    "    'distilbert-base-uncased',\n",
    "    config=large_config,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Small model: same base configuration, but with dim overridden to 72.\n",
    "# It is constructed from the config alone (no from_pretrained call), so no\n",
    "# checkpoint weights are loaded into it here.\n",
    "config = DistilBertConfig.from_pretrained(\n",
    "    'distilbert-base-uncased',\n",
    "    num_labels=num_labels,\n",
    "    dim=72,\n",
    ")\n",
    "model = TFDistilBertForSequenceClassification(config=config)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 70,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Adam's epsilon is a small numerical-stability constant and must be the float\n",
    "# 1e-8; spelling it as `13-8` is integer subtraction and evaluates to 5.\n",
    "opt = tf.keras.optimizers.Adam(learning_rate=3e-5, epsilon=1e-8)\n",
    "# from_logits=True: the loss treats the model outputs as unnormalised logits.\n",
    "loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)\n",
    "metric = tf.keras.metrics.SparseCategoricalAccuracy('accuracy')\n",
    "model.compile(optimizer=opt, loss=loss, metrics=[metric])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 77,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:absl:Overwrite dataset info from restored data version.\n",
      "INFO:absl:Field info.description from disk and from code do not match. Keeping the one from code.\n",
      "INFO:absl:Field info.location from disk and from code do not match. Keeping the one from code.\n",
      "INFO:absl:Reusing dataset glue (/homes/gws/psturm/tensorflow_datasets/glue/sst2/0.0.2)\n",
      "INFO:absl:Constructing tf.data.Dataset for split None, from /homes/gws/psturm/tensorflow_datasets/glue/sst2/0.0.2\n"
     ]
    }
   ],
   "source": [
    "# Load GLUE SST-2 through tensorflow_datasets; with_info=True makes the call\n",
    "# return the dataset info object alongside the data.\n",
    "data, info = tensorflow_datasets.load(\n",
    "    name='glue/sst2',\n",
    "    with_info=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 78,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Convert the raw GLUE examples of each split into features with the tokenizer.\n",
    "train_dataset = glue_convert_examples_to_features(\n",
    "    data['train'], tokenizer, max_length=128, task=task\n",
    ")\n",
    "# The validation split is additionally grouped into batches of 16.\n",
    "valid_dataset = glue_convert_examples_to_features(\n",
    "    data['validation'], tokenizer, max_length=128, task=task\n",
    ").batch(16)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 92,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[-0.02245433, -0.00332783],\n",
       "       [-0.02258434, -0.00324947],\n",
       "       [-0.02278818, -0.00320506],\n",
       "       ...,\n",
       "       [-0.02278283, -0.00315616],\n",
       "       [-0.02254031, -0.00312885],\n",
       "       [-0.02276037, -0.00321589]], dtype=float32)"
      ]
     },
     "execution_count": 92,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Run inference over the batched validation set and discard the predictions.\n",
    "# NOTE(review): presumably this forward pass is only here to build the model's\n",
    "# weights before they are inspected and overwritten below - confirm.\n",
    "_ = model.predict(valid_dataset)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 97,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Model: \"tf_distil_bert_for_sequence_classification_6\"\n",
      "_________________________________________________________________\n",
      "Layer (type)                 Output Shape              Param #   \n",
      "=================================================================\n",
      "distilbert (TFDistilBertMain multiple                  5035536   \n",
      "_________________________________________________________________\n",
      "pre_classifier (Dense)       multiple                  5256      \n",
      "_________________________________________________________________\n",
      "classifier (Dense)           multiple                  146       \n",
      "_________________________________________________________________\n",
      "dropout_121 (Dropout)        multiple                  0         \n",
      "=================================================================\n",
      "Total params: 5,040,938\n",
      "Trainable params: 5,040,938\n",
      "Non-trainable params: 0\n",
      "_________________________________________________________________\n"
     ]
    }
   ],
   "source": [
    "# Print the per-layer parameter summary of the small model.\n",
    "model.summary()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 135,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Initialise the small model's first layer from the large model's first layer:\n",
    "# every axis whose size equals the large model's dim is cut down to its leading\n",
    "# small-dim entries, and all other axes are copied in full.\n",
    "# The two widths are read from the configs so this cell stays in sync with the\n",
    "# cells that build them, instead of repeating the literals 768 and 72.\n",
    "large_dim = large_config.dim\n",
    "small_dim = config.dim\n",
    "\n",
    "weight_list = []\n",
    "for large_weight in large_model.layers[0].weights:\n",
    "    # One slice object per axis of this weight.\n",
    "    slice_indices = tuple(\n",
    "        slice(0, small_dim) if s == large_dim else slice(None)\n",
    "        for s in large_weight.get_shape().as_list()\n",
    "    )\n",
    "    weight_list.append(large_weight[slice_indices])\n",
    "\n",
    "# Weights are matched positionally, so the list keeps the large model's order.\n",
    "model.layers[0].set_weights(weight_list)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "ename": "NameError",
     "evalue": "name 'valid_dataset' is not defined",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mNameError\u001b[0m                                 Traceback (most recent call last)",
      "\u001b[0;32m<ipython-input-1-c7e642dd4a5b>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0;32mfor\u001b[0m \u001b[0mitem\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mvalid_dataset\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtake\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m      2\u001b[0m     \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mitem\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mNameError\u001b[0m: name 'valid_dataset' is not defined"
     ]
    }
   ],
   "source": [
    "# Print the first element of valid_dataset to inspect its structure.\n",
    "# valid_dataset is batched with .batch(16) where it is built, so this element\n",
    "# is one batch. That cell must have been run in the current kernel first.\n",
    "for item in valid_dataset.take(1):\n",
    "    print(item)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
